import pytest
from core.infra.database.models.base_model import OrmBase
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import create_database, database_exists


@pytest.fixture(scope='session')
def test_db():
    test_db_url = 'pg+asyncpg://admin:password@127.0.0.1:5432/app_db'
    if not database_exists(test_db_url):
        create_database(test_db_url)

    engine = create_engine(test_db_url)
    base = OrmBase()

    base.metadata.create_all(engine)
    try:
        yield engine
    finally:
        base.metadata.drop_all(engine)


@pytest.fixture(scope='function')
def test_session(test_db):
    connection = test_db.connect()

    trans = connection.begin()
    session = sessionmaker()(bind=connection)

    session.begin_nested()  # SAVEPOINT

    @event.listens_for(session, 'after_transaction_end')
    def restart_savepoint(session, transaction):
        """
        Each time that SAVEPOINT ends, reopen it
        """
        if transaction.nested and not transaction._parent.nested:
            session.begin_nested()

    yield session

    session.close()
    trans.rollback()
    connection.close()
